# Copyright (c) OpenMMLab. All rights reserved.
from mmengine import read_base

with read_base():
    from .._base_.datasets.kitti_mono3d import *
    from .._base_.models.pgd import *
    from .._base_.schedules.mmdet_schedule_1x import *
    from .._base_.default_runtime import *

from mmcv.transforms.processing import Resize
from mmengine.optim.scheduler.lr_scheduler import LinearLR, MultiStepLR

from mmdet3d.datasets.transforms.formating import Pack3DDetInputs
from mmdet3d.datasets.transforms.loading import (LoadAnnotations3D,
                                                 LoadImageFromFileMono3D)
from mmdet3d.datasets.transforms.transforms_3d import RandomFlip3D
from mmdet3d.models.data_preprocessors.data_preprocessor import \
    Det3DDataPreprocessor
from mmdet3d.models.losses.uncertain_smooth_l1_loss import \
    UncertainSmoothL1Loss
from mmdet3d.models.task_modules.coders.pgd_bbox_coder import PGDBBoxCoder

# model settings
# Overrides merged (recursively, via ConfigDict.update) into the PGD model
# inherited from `_base_/models/pgd`, adapting it to 3-class KITTI mono3D.
model.update(
    data_preprocessor=dict(
        type=Det3DDataPreprocessor,
        mean=[103.530, 116.280, 123.675],
        std=[1.0, 1.0, 1.0],
        bgr_to_rgb=False,
        pad_size_divisor=32),
    backbone=dict(frozen_stages=0),
    neck=dict(start_level=0, num_outs=4),
    bbox_head=dict(
        num_classes=3,
        bbox_code_size=7,
        pred_attrs=False,
        pred_velo=False,
        pred_bbox2d=True,
        use_onlyreg_proj=True,
        strides=(4, 8, 16, 32),
        regress_ranges=((-1, 64), (64, 128), (128, 256), (256, 1e8)),
        # Regression groups, in order: offset, depth, size, rot, kpts, bbox2d
        # (27 dims in total; train_cfg.code_weight below follows this layout).
        group_reg_dims=(2, 1, 3, 1, 16, 4),
        # One (256, ) branch spec per regression group listed above.
        reg_branch=(
            (256, ),  # offset
            (256, ),  # depth
            (256, ),  # size
            (256, ),  # rot
            (256, ),  # kpts
            (256, ),  # bbox2d
        ),
        centerness_branch=(256, ),
        loss_cls=dict(
            type='mmdet.FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
        loss_bbox=dict(
            type='mmdet.SmoothL1Loss',
            beta=1.0 / 9.0,
            loss_weight=1.0),
        loss_dir=dict(
            type='mmdet.CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
        loss_centerness=dict(
            type='mmdet.CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        # Depth-classification settings: uniform bins over depth_range.
        # NOTE(review): depth_bins=8 appears to equal (70 - 0) / 10 + 1;
        # presumably it must stay consistent with depth_range/depth_unit.
        use_depth_classifier=True,
        depth_branch=(256, ),
        depth_range=(0, 70),
        depth_unit=10,
        division='uniform',
        depth_bins=8,
        pred_keypoints=True,
        weight_dim=1,
        loss_depth=dict(
            type=UncertainSmoothL1Loss, alpha=1.0, beta=3.0, loss_weight=1.0),
        bbox_coder=dict(
            type=PGDBBoxCoder,
            base_depths=((28.01, 16.32), ),
            # Presumably one size prior per class (num_classes=3); confirm
            # the ordering against the dataset's class list.
            base_dims=(
                (0.8, 1.73, 0.6),
                (1.76, 1.73, 0.6),
                (3.9, 1.56, 1.6),
            ),
            code_size=7)),
    # Per-dimension regression weights laid out like group_reg_dims:
    # 1.0 for the 7 base dims (offset 2 + depth 1 + size 3 + rot 1),
    # 0.2 for the 16 keypoint-offset dims and 1.0 for the 4 2D distance
    # targets -- 27 entries in total.
    train_cfg=dict(code_weight=[1.0] * 7 + [0.2] * 16 + [1.0] * 4),
    test_cfg=dict(nms_pre=100, nms_thr=0.05, score_thr=0.001, max_per_img=20))

backend_args = None

# Training data flow: read the image, load its 2D + 3D annotations (including
# per-box depths), resize, randomly flip, then pack the keys the model uses.
train_pipeline = [
    dict(type=LoadImageFromFileMono3D, backend_args=backend_args),
    dict(
        type=LoadAnnotations3D,
        with_bbox=True,
        with_label=True,
        with_attr_label=False,
        with_bbox_3d=True,
        with_label_3d=True,
        with_bbox_depth=True),
    dict(type=Resize, scale=(1242, 375), keep_ratio=True),
    dict(type=RandomFlip3D, flip_ratio_bev_horizontal=0.5),
    dict(
        type=Pack3DDetInputs,
        keys=[
            'img',
            'gt_bboxes',
            'gt_bboxes_labels',
            'gt_bboxes_3d',
            'gt_labels_3d',
            'centers_2d',
            'depths',
        ]),
]

# Evaluation data flow: image only; scale_factor=1.0 (presumably no rescale).
test_pipeline = [
    dict(type=LoadImageFromFileMono3D, backend_args=backend_args),
    dict(type=Resize, scale_factor=1.0),
    dict(type=Pack3DDetInputs, keys=['img']),
]

# Plug the pipelines above into the dataloaders inherited from the base
# kitti_mono3d dataset config; validation and test share test_pipeline.
train_dataloader.update(
    batch_size=3, num_workers=3, dataset=dict(pipeline=train_pipeline))
test_dataloader.update(dataset=dict(pipeline=test_pipeline))
val_dataloader.update(dataset=dict(pipeline=test_pipeline))

# optimizer
# Override the base schedule's optimizer lr; bias_lr_mult / bias_decay_mult
# scale the lr and weight decay applied to bias parameters, and clip_grad
# enables gradient-norm clipping.
optim_wrapper.update(
    optimizer=dict(lr=0.001),
    paramwise_cfg=dict(bias_lr_mult=2.0, bias_decay_mult=0.0),
    clip_grad=dict(max_norm=35, norm_type=2))

# learning rate
# Iteration-based linear warm-up (from lr / 3) over the first 500 iterations,
# followed by an epoch-based x0.1 step decay at epochs 32 and 44 of 48.
param_scheduler = [
    dict(
        type=LinearLR,
        start_factor=1.0 / 3,
        by_epoch=False,
        begin=0,
        end=500),
    dict(
        type=MultiStepLR,
        begin=0,
        end=48,
        by_epoch=True,
        milestones=[32, 44],
        gamma=0.1),
]

# 48-epoch run (matching the MultiStepLR end above), validating every 2
# epochs. The auto-lr-scaling reference batch of 12 presumably corresponds to
# 4 GPUs x batch_size=3 -- confirm against the launch setup.
train_cfg.update(max_epochs=48, val_interval=2)
auto_scale_lr.update(base_batch_size=12)
